"""
Created on Wed Dec  8 15:20:07 2021.

Optimal Policy Trees: Tree Functions - Python implementation

Can be used under Creative Commons Licence CC BY-SA
Michael Lechner, SEW, University of St. Gallen, Switzerland

# -*- coding: utf-8 -*-
"""
import random
import math
from concurrent import futures

import pandas as pd
import numpy as np
import ray
import scipy.stats as sct
from numba import njit

from mcf import optp_tree_add_functions as optp_ta
from mcf import general_purpose as gp


def get_values_ordered(single_x_np, ps_np_diff, values, no_of_values,
                       with_numba=True):
    """
    Sort values according to policy score differences (dispatcher only).

    No computation happens here: the call is forwarded to the numba-compiled
    or to the pure-numpy implementation, depending on ``with_numba``.

    Parameters
    ----------
    single_x_np : 1D numpy array. Covariate.
    ps_np_diff : 2D numpy array. Policy scores as differences.
    values : 1D numpy array. All unique values of x.
    no_of_values : Int. Number of unique values of x.
    with_numba : Boolean. Use numba module. Default is True.

    Returns
    -------
    values_sorted : 2D numpy array. Sorted values (one column per column of
        ``ps_np_diff``).
    no_of_ps : Int. Number of columns of ``ps_np_diff``.

    """
    sorter = (get_values_ordered_numba if with_numba
              else get_values_ordered_no_numba)
    return sorter(single_x_np, ps_np_diff, values, no_of_values)


def get_values_ordered_no_numba(single_x_np, ps_np_diff, values, no_of_values):
    """
    Sort values according to policy score differences.

    For each unique value of x, the mean of every column of ``ps_np_diff`` is
    computed over the observations having that value. The unique values are
    then ordered by these means, separately for each column.

    Parameters
    ----------
    single_x_np : 1D numpy array. Covariate.
    ps_np_diff : 2D numpy array. Policy scores as difference.
    values : 1D numpy array. All unique values of x.
    no_of_values : Int. Number of unique values of x.

    Returns
    -------
    values_sorted : 2D numpy array (no_of_values x no_of_ps). Column j holds
        ``values`` ordered by increasing mean of ``ps_np_diff[:, j]``.
    no_of_ps : Int. Number of columns of ``ps_np_diff``.

    """
    no_of_ps = np.size(ps_np_diff, axis=1)
    mean_y_by_values = np.empty((no_of_values, no_of_ps))
    for i, val in enumerate(values):
        # Column means of the policy score differences where x == val.
        mean_y_by_values[i, :] = np.mean(ps_np_diff[single_x_np == val],
                                         axis=0)
    values_sorted = np.empty((no_of_values, no_of_ps))
    for j in range(no_of_ps):
        values_sorted[:, j] = values[np.argsort(mean_y_by_values[:, j])]
    return values_sorted, no_of_ps


@njit
def get_values_ordered_numba(single_x_np, ps_np_diff, values, no_of_values):
    """
    Sort values according to policy score differences (numba version).

    For each unique value of x, the mean of every column of ``ps_np_diff`` is
    computed over the observations having that value. The unique values are
    then ordered by these means, separately for each column. Intended to
    match ``get_values_ordered_no_numba``.

    Parameters
    ----------
    single_x_np : 1D numpy array. Covariate.
    ps_np_diff : 2D numpy array. Policy scores as difference.
    values : 1D numpy array. All unique values of x.
    no_of_values : Int. Number of unique values of x.

    Returns
    -------
    values_sorted : 2D numpy array (no_of_values x no_of_ps). Column j holds
        ``values`` ordered by increasing mean of ``ps_np_diff[:, j]``.
    no_of_ps : Int. Number of columns of ``ps_np_diff``.

    """
    no_of_ps = np.shape(ps_np_diff)[1]  # np.shape form used for numba
    mean_y_by_values = np.empty((no_of_values, no_of_ps))
    for i, val in enumerate(values):
        ps_group = ps_np_diff[np.where(single_x_np == val)]
        # Column means in an explicit loop (kept for numba compatibility).
        for j in range(no_of_ps):
            mean_y_by_values[i, j] = np.mean(ps_group[:, j])
    values_sorted = np.empty((no_of_values, no_of_ps))
    for j in range(no_of_ps):
        indices = np.argsort(mean_y_by_values[:, j])
        values_sorted[:, j] = values[indices]
    return values_sorted, no_of_ps


def get_values_cont_x(data_vector, no_of_evalupoints, with_numba=True):
    """Get cut-off points for tree splitting for continuous variables.

    Thin dispatcher choosing the numba-compiled or the pure-numpy
    implementation.

    Parameters
    ----------
    data_vector : Numpy-1D array. Values of the continuous covariate.
    no_of_evalupoints : Int. Maximum number of cut-off points.
    with_numba : Boolean. Use numba module. Default is True.

    Returns
    -------
    Numpy 1D-array. Sorted cut-off-points

    """
    if not with_numba:
        return get_values_cont_x_no_numba(data_vector, no_of_evalupoints)
    return get_values_cont_x_numba(data_vector, no_of_evalupoints)


@njit
def get_values_cont_x_numba(data_vector, no_of_evalupoints):
    """Get cut-off points for tree splitting for continuous variables.

    Numba-compiled counterpart of ``get_values_cont_x_no_numba``.

    Parameters
    ----------
    data_vector : Numpy-1D array. Values of the covariate. It need not be
        sorted or unique, because ``np.unique`` is applied first.
    no_of_evalupoints : Int. Maximum number of cut-off points.
        NOTE(review): values <= 0 are not guarded against (division by zero
        when there are at least 10 unique values).

    Returns
    -------
    Numpy 1D-array. Sorted cut-off-points: all unique values if there are
    fewer than ``no_of_evalupoints + 10`` of them, otherwise
    ``no_of_evalupoints`` roughly equally spaced unique values.

    """
    data_vector = np.unique(data_vector)  # sorted unique values
    obs = len(data_vector)
    if no_of_evalupoints > (obs - 10):
        # Few distinct values: every distinct value is a candidate.
        data_vector_new = data_vector
    else:
        # Grid of no_of_evalupoints+1 positions from obs/k to obs. Only the
        # first no_of_evalupoints are used, so position obs is never read.
        indices = np.linspace(obs / no_of_evalupoints, obs,
                              no_of_evalupoints+1)
        data_vector_new = np.empty(no_of_evalupoints)
        for i in range(no_of_evalupoints):
            indices_i = np.uint32(indices[i])  # truncate to integer index
            data_vector_new[i] = data_vector[indices_i]
    return data_vector_new


def get_values_cont_x_no_numba(data_vector, no_of_evalupoints):
    """Get cut-off points for tree splitting for continuous variables.

    Pure-numpy counterpart of ``get_values_cont_x_numba``, used when numba
    is switched off.

    Parameters
    ----------
    data_vector : Numpy-1D array. Values of the covariate. It need not be
        sorted or unique, because ``np.unique`` is applied first.
    no_of_evalupoints : Int. Maximum number of cut-off points.

    Returns
    -------
    Numpy 1D-array. Sorted cut-off-points: all unique values if there are
    fewer than ``no_of_evalupoints + 10`` of them, otherwise
    ``no_of_evalupoints`` roughly equally spaced unique values.

    """
    unique_vals = np.unique(data_vector)
    n_unique = len(unique_vals)
    if no_of_evalupoints <= n_unique - 10:
        positions = np.linspace(n_unique / no_of_evalupoints, n_unique,
                                no_of_evalupoints, endpoint=False)
        return unique_vals[positions.astype(np.uint32)]
    return unique_vals


def merge_trees(tree_l, tree_r, name_x_m, type_x_m, val_x, treedepth):
    """Merge trees and add new split.

    A new root node holding the split is created and placed in front of all
    nodes of ``tree_l`` followed by all nodes of ``tree_r``.

    Node layout (list of length 9):
    0: Node identifier (INT: 0-...)
    1: Parent knot
    2: Child node left
    3: Child node right
    4: Type of node (1: Terminal node, no further splits
                    0: previous node that lead already to further splits)
    5: String: Name of variable used for decision of next split
    6: x_type of variable (policy categorisation, maybe different from MCF)
    7: If x_type = 'unordered': Set of values that goes to left daughter
    7: If x_type = 0: Cut-off value (larger goes to right daughter)
    8: List of Treatment state for both daughters [left, right]

    Parameters
    ----------
    tree_l : List of lists. Left tree (treatment index if treedepth == 2).
    tree_r : List of lists. Right tree (treatment index if treedepth == 2).
    name_x_m : String. Name of variables used for splitting.
    type_x_m : String. Type of variables used for splitting.
    val_x : Float, Int, or set of Int. Values used for splitting.
    treedepth : Int. Current level of tree. 2: final split.

    Returns
    -------
    new_tree : List of lists. The merged trees.

    """
    leaf = [None] * 9
    # NOTE(review): identifiers are drawn at random from 0..99999, so two
    # nodes of one tree may (rarely) share an identifier.
    leaf[0], leaf[1] = random.randrange(100000), None
    leaf[5], leaf[6], leaf[7] = name_x_m, type_x_m, val_x
    if treedepth == 2:  # Final split (defines 2 final leaves)
        leaf[2], leaf[3], leaf[4] = None, None, 1
        leaf[8] = [tree_l, tree_r]  # For 1st tree --> treatment states
        return [leaf]
    # Link the new root with the roots of both subtrees.
    leaf[2], leaf[3], leaf[4] = tree_l[0][0], tree_r[0][0], 0
    tree_l[0][1], tree_r[0][1] = leaf[0], leaf[0]
    # Concatenate instead of pre-sizing from len(tree_l): this stays correct
    # if the two subtrees have different numbers of nodes.
    return [leaf, *tree_l, *tree_r]


def evaluate_leaf(data_ps, c_dict):
    """Evaluate final value of leaf taking restriction into account.

    Dispatches to the numba or the pure-numpy implementation according to
    ``c_dict['with_numba']``.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores (observations in rows).
    c_dict : Dict. Controls. Keys used here or by the callee:
        'with_numba', 'no_of_treatments', 'max_by_treat', 'restricted',
        'costs_of_treat'.

    Returns
    -------
    treat_ind: Int. Index of treatment.
    reward: Float. Value of leaf.
    obs_all: Numpy 1D-array. Number of observations per treatment.

    """
    if not c_dict['with_numba']:
        return evaluate_leaf_no_numba(data_ps, c_dict)
    return evaluate_leaf_numba(
        data_ps, c_dict['no_of_treatments'], c_dict['max_by_treat'],
        c_dict['restricted'], c_dict['costs_of_treat'])


@njit
def evaluate_leaf_numba(data_ps, no_of_treatments, max_by_treat, restricted,
                        costs_of_treat):
    """Evaluate final value of leaf taking restriction into account.

    All observations of the leaf are assigned to the single treatment with
    the highest sum of policy scores minus costs. If ``restricted`` is set,
    treatments whose capacity is exceeded by the leaf size are excluded. If
    that excludes every treatment, the one with the smallest excess is kept.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores (observations x treatments).
    no_of_treatments : Int. Number of treatments.
    max_by_treat : Array of Int. Maximum number of observations per
        treatment (only used if ``restricted``).
    restricted : Boolean. Whether the capacity restrictions apply.
    costs_of_treat : Numpy 1D array. Cost per observation of each treatment.

    Returns
    -------
    treat_ind: Int. Index of treatment.
    reward: Float. Value of leaf.
    obs_all: Numpy 1D-array. Number of observations per treatment (all
        observations in ``treat_ind``).

    """
    obs_all, obs = np.zeros(no_of_treatments), len(data_ps)
    indi = np.arange(no_of_treatments)
    if restricted:
        # Excess of the leaf size over each treatment's capacity.
        diff_obs = obs - max_by_treat
        treat_not_ok = diff_obs > 0.999
        if np.any(treat_not_ok):
            treat_ok = np.invert(treat_not_ok)
            data_ps_tmp = data_ps[:, treat_ok]
            if data_ps_tmp.size == 0:
                # No admissible treatment left: keep the least violated one.
                idx = np.argmin(diff_obs)
                treat_ok[idx] = True
                data_ps = data_ps[:, treat_ok]
            else:
                data_ps = data_ps_tmp
            indi = indi[treat_ok]      # Keep only admissible treatments
            costs_of_treat = costs_of_treat[indi]
    # Total reward per (remaining) treatment: sum of scores minus costs.
    reward_by_treat = data_ps.sum(axis=0) - costs_of_treat * obs
    max_i = np.argmax(reward_by_treat)
    obs_all[indi[max_i]] = obs
    return indi[max_i], reward_by_treat[max_i], obs_all


def evaluate_leaf_no_numba(data_ps, c_dict):
    """Evaluate final value of leaf taking restriction into account.

    Pure-numpy counterpart of ``evaluate_leaf_numba``. All observations of
    the leaf are assigned to the single treatment with the highest sum of
    policy scores minus costs. If ``c_dict['restricted']`` is set,
    treatments whose capacity ``c_dict['max_by_treat']`` is exceeded by the
    leaf size are excluded. If that excludes every treatment, the one with
    the smallest excess is kept.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores (observations x treatments).
    c_dict : Dict. Controls; uses 'no_of_treatments', 'restricted',
        'max_by_treat' and 'costs_of_treat'.

    Returns
    -------
    treat_ind: Int. Index of treatment.
    reward: Float. Value of leaf.
    obs_all: Numpy 1D-array. Number of observations per treatment.

    """
    n_obs = len(data_ps)
    obs_all = np.zeros(c_dict['no_of_treatments'])
    treat_idx = np.arange(c_dict['no_of_treatments'])
    costs = c_dict['costs_of_treat']
    if c_dict['restricted']:
        excess = n_obs - c_dict['max_by_treat']
        too_many = excess > 0.999
        if np.any(too_many):
            allowed = np.logical_not(too_many)
            scores = data_ps[:, allowed]
            if scores.size == 0:
                # Every treatment violates its capacity: keep the one with
                # the smallest excess.
                allowed[np.argmin(excess)] = True
                scores = data_ps[:, allowed]
            data_ps = scores
            treat_idx = treat_idx[allowed]  # drop treatments over capacity
            costs = costs[treat_idx]
    rewards = data_ps.sum(axis=0) - costs * n_obs
    best = np.argmax(rewards)
    obs_all[treat_idx[best]] = n_obs
    return treat_idx[best], rewards[best], obs_all


def seq_tree_search(data_ps, data_ps_diff, data_x, name_x, type_x, values_x,
                    c_dict):
    """Build sequential tree.

    The tree is grown greedily, level by level. Each active leaf of the
    current level is split at its locally best split, without look-ahead.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores.
    data_ps_diff : Numpy array. Policy scores as differences to cat 0.
    data_x : Numpy array. Policy variables.
    name_x : List of strings. Name of policy variables.
    type_x : List of strings. Type of policy variable.
    values_x : List of sets. Values of x for non-continuous variables.
    c_dict : Dict. Parameters ('st_depth', 'st_min_leaf_size', plus the
        keys used by evaluate_leaf and get_values_cont_x).

    Returns
    -------
    tree : List of lists. Current tree.
    reward : Float. Reward of the best split of the last leaf processed
        (-inf if that leaf had no admissible split).
    no_by_treat : Numpy array or None. Number of observations by treatment
        for that same split.
    NOTE(review): reward and no_by_treat describe the last processed leaf
    only, not the whole tree. Confirm what callers expect.

    Content of tree for each node:
    0: Node identifier (INT: 0-...)
    1: Parent knot
    2: Child node left
    3: Child node right
    4: Type of node (2: Active -> will be further splitted or made terminal
                    1: Terminal node, no further splits
                    0: previous node that lead already to further splits)
    5: String: Name of variable used for decision of next split
    6: x_type of variable (policy categorisation, maybe different from MCF)
    7: If x_type = 'unordered': Set of values that goes to left daughter
    7: If x_type = 0: Cut-off value (larger goes to right daughter)
    8: List of Treatment state for both daughters [left, right]
    9: Level (0-c_dict['st_depth'])
    10: Indices of data: Numpy series
    11: Treatment of leaf

    """
    def index_from_leaf_id(tree, leaf_id):
        """Return the list position of the node with identifier leaf_id."""
        for leaf_no, leaf in enumerate(tree):
            if leaf[0] == leaf_id:
                return leaf_no
        raise Exception('Leaf_id not found in tree.')

    def leaf_in_tree(tree, leaf):
        """Check whether this node object is still part of the tree."""
        return any(node is leaf for node in tree)

    def add_leaves_to_tree(tree, best_treat_l, best_treat_r, best_name_x,
                           best_type_x, best_val_x, best_left, best_right,
                           indices, level, final, parent_leaf):
        """Attach the best split of parent_leaf, or prune if there is none."""
        if best_treat_l is None or best_treat_r is None:   # status --> final
            # No admissible split: the grandparent becomes terminal, and
            # both of its children (parent_leaf and its sibling) are removed.
            # For the root (parent_leaf[1] is None) the lookup raises.
            grandparent = tree[index_from_leaf_id(tree, parent_leaf[1])]
            grandparent[4] = 1
            removed_ids = (grandparent[2], grandparent[3])
            # Rebuild the list instead of deleting two precomputed positions:
            # the first `del` would shift the second position and remove the
            # wrong node. Daughters the sibling may already have added at the
            # next level are dropped as well.
            tree[:] = [leaf for leaf in tree
                       if leaf[0] not in removed_ids
                       and leaf[1] not in removed_ids]
            return tree
        # Create and assign to left & right daughter
        daughter_left, daughter_right = [None] * 12, [None] * 12
        daughter_left[0] = random.randrange(100000)
        daughter_right[0] = random.randrange(100000)
        daughter_left[1], daughter_right[1] = parent_leaf[0], parent_leaf[0]
        daughter_left[4], daughter_right[4] = 2, 2
        daughter_left[9], daughter_right[9] = level + 1, level + 1
        daughter_left[10] = indices[best_left]
        daughter_right[10] = indices[best_right]
        daughter_left[11], daughter_right[11] = best_treat_l, best_treat_r
        # Change values in parent leaf
        parent_leaf[2], parent_leaf[3] = daughter_left[0], daughter_right[0]
        parent_leaf[4] = 1 if final else 0
        parent_leaf[5], parent_leaf[6] = best_name_x, best_type_x
        parent_leaf[7] = best_val_x
        parent_leaf[8] = [best_treat_l, best_treat_r]
        # Exchange the parent leaf in the tree
        index_of_parent = index_from_leaf_id(tree, parent_leaf[0])
        tree[index_of_parent] = parent_leaf.copy()
        if not final:  # final daughters are terminal; stored in parent[8]
            tree.append(daughter_left)
            tree.append(daughter_right)
        return tree

    def list_of_leaves_f(level, tree):
        """Return the active (status 2) leaves of the given level."""
        return [leaf for leaf in tree if (leaf[9] == level
                                          and leaf[4] == 2)]

    def initial_node_table(obs):
        """Return a tree consisting of an active root holding all obs."""
        leaf = [None] * 12
        leaf[0] = random.randrange(100000)
        leaf[1], leaf[4], leaf[9], leaf[10] = None, 2, 0, np.arange(obs)
        return [leaf]

    def get_leaf_data(data_x, data_ps_diff, data_ps, current_leaf):
        """Return the data rows belonging to current_leaf and their index."""
        indices_l = current_leaf[10]
        return (data_x[indices_l], data_ps_diff[indices_l], data_ps[indices_l],
                current_leaf[10])

    tree, no_of_x = initial_node_table(len(data_ps)), len(type_x)
    reward, no_by_treat = -math.inf, None
    for level in range(c_dict['st_depth']):
        # NOTE(review): at the final level this is 2 * st_min_leaf_size,
        # whereas tree_search uses ft_min_leaf_size for its final split.
        # Confirm intended.
        min_leaf_size = c_dict['st_min_leaf_size'] * 2**(
            c_dict['st_depth'] - level)
        list_of_leaves = list_of_leaves_f(level, tree)
        if not list_of_leaves:  # every branch became terminal earlier
            break
        final = (c_dict['st_depth'] - (level + 1)) == 0
        for parent_leaf in list_of_leaves:
            if not leaf_in_tree(tree, parent_leaf):
                # Pruned earlier in this level because its sibling had no
                # admissible split.
                continue
            reward, no_by_treat = -math.inf, None
            (data_x_leaf, data_ps_diff_leaf, data_ps_leaf, indices_leaf
             ) = get_leaf_data(data_x, data_ps_diff, data_ps, parent_leaf)
            obs_leaf = len(indices_leaf)
            best_treat_l = best_treat_r = best_name_x = best_type_x = None
            best_val_x = best_left = best_right = None
            for m_i in range(no_of_x):
                if type_x[m_i] == 'cont':
                    # NOTE(review): passing obs_leaf as the number of
                    # evaluation points means all unique values are checked.
                    values_x_to_check = get_values_cont_x(
                        data_x_leaf[:, m_i], obs_leaf,
                        with_numba=c_dict['with_numba'])
                elif type_x[m_i] == 'disc':
                    values_x_to_check = values_x[m_i][:]
                else:
                    values_x_to_check = optp_ta.combinations_categorical(
                            data_x_leaf[:, m_i], data_ps_diff_leaf, c_dict)
                for val_x in values_x_to_check:
                    if type_x[m_i] == 'unord':
                        left = np.isin(data_x_leaf[:, m_i], val_x)
                    else:
                        left = data_x_leaf[:, m_i] <= (val_x + 1e-15)
                    obs_left = np.count_nonzero(left)
                    if not (min_leaf_size <= obs_left
                            <= (len(left) - min_leaf_size)):
                        continue
                    right = np.invert(left)
                    treat_l, reward_l, no_by_treat_l = evaluate_leaf(
                        data_ps_leaf[left], c_dict)
                    treat_r, reward_r, no_by_treat_r = evaluate_leaf(
                        data_ps_leaf[right], c_dict)
                    if reward_r + reward_l > reward:
                        reward = reward_l + reward_r
                        no_by_treat = no_by_treat_l + no_by_treat_r
                        best_treat_l, best_treat_r = treat_l, treat_r
                        best_left, best_right = left.copy(), right.copy()
                        best_name_x, best_type_x = name_x[m_i], type_x[m_i]
                        best_val_x = val_x
            tree = add_leaves_to_tree(
                tree, best_treat_l, best_treat_r, best_name_x, best_type_x,
                best_val_x, best_left, best_right, indices_leaf, level, final,
                parent_leaf)
    return tree, reward, no_by_treat


def tree_search(data_ps, data_ps_diff, data_x, name_x, type_x, values_x,
                c_dict, treedepth, no_further_splits=False):
    """Build tree.

    Recursive exhaustive search. For every variable and every candidate
    split value, the best subtrees of depth ``treedepth - 1`` are built on
    both sides, and the split with the highest total reward is kept.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores.
    data_ps_diff : Numpy array. Policy scores as differences.
    data_x : Numpy array. Policy variables.
    name_x : List of strings. Name of policy variables.
    type_x : List of strings. Type of policy variable.
    values_x : List of sets. Values of x for non-continuous variables.
    c_dict : Dict. Parameters.
    treedepth : Int. Current depth of tree (1: leaf, only evaluated).
    no_further_splits : Boolean.
        Further splits do not matter. Take next (1st) split as final. Default
        is False.

    Returns
    -------
    tree : List of lists. Current tree. At treedepth 1 this is the index of
        the best treatment; None if no admissible split was found.
    reward : Float. Total reward that comes from this tree (-inf if no
        admissible split was found).
    no_by_treat : List of int. Number of treated by treatment state (0-...)

    """
    if treedepth == 1:  # Evaluate tree
        tree, reward, no_by_treat = evaluate_leaf(data_ps, c_dict)
    else:
        # Below the root: if all observations share the same best
        # (cost-adjusted) treatment, the search stops after the first
        # admissible split (see the return at the end of the inner loop).
        if not no_further_splits and (treedepth < c_dict['ft_depth']):
            no_further_splits = only_1st_tree_fct3(data_ps, c_dict)
        # Both sides must be large enough to be split further down to the
        # leaves (factor 1 for the last split, doubling per extra level).
        min_leaf_size = c_dict['ft_min_leaf_size'] * 2**(treedepth - 2)
        no_of_x, reward = len(type_x), -math.inf
        tree = no_by_treat = None
        for m_i in range(no_of_x):
            if c_dict['with_output']:
                if treedepth == c_dict['ft_depth']:
                    print(f'{name_x[m_i]:20s}  {m_i / no_of_x * 100:4.1f}%',
                          ' of variables completed')
            # Candidate split values for variable m_i.
            if type_x[m_i] == 'cont':
                values_x_to_check = get_values_cont_x(
                    data_x[:, m_i], c_dict['ft_no_of_evalupoints'],
                    with_numba=c_dict['with_numba'])
            elif type_x[m_i] == 'disc':
                values_x_to_check = values_x[m_i][:]
            else:
                # Remaining type (presumably 'unord'): category combinations
                # from the data below the root, stored values at the root.
                if treedepth < c_dict['ft_depth']:
                    values_x_to_check = optp_ta.combinations_categorical(
                        data_x[:, m_i], data_ps_diff, c_dict)
                else:
                    values_x_to_check = values_x[m_i][:]
            for val_x in values_x_to_check:
                if type_x[m_i] == 'unord':
                    left = np.isin(data_x[:, m_i], val_x)
                else:
                    # x <= cut-off goes left (1e-15 presumably guards
                    # against float rounding).
                    left = data_x[:, m_i] <= (val_x + 1e-15)
                obs_left = np.count_nonzero(left)
                if not (min_leaf_size <= obs_left
                        <= (len(left) - min_leaf_size)):
                    continue
                right = np.invert(left)
                tree_l, reward_l, no_by_treat_l = tree_search(
                    data_ps[left, :], data_ps_diff[left, :], data_x[left, :],
                    name_x, type_x, values_x, c_dict, treedepth - 1,
                    no_further_splits)
                tree_r, reward_r, no_by_treat_r = tree_search(
                    data_ps[right, :], data_ps_diff[right, :],
                    data_x[right, :], name_x, type_x, values_x, c_dict,
                    treedepth - 1, no_further_splits)
                # Penalise subtrees whose combined allocation violates the
                # capacity restrictions.
                if c_dict['restricted']:
                    reward_l, reward_r = adjust_reward(
                        no_by_treat_l, no_by_treat_r, reward_l, reward_r,
                        c_dict)
                # Keep the best split found so far.
                if reward_l + reward_r > reward:
                    reward = reward_l + reward_r
                    no_by_treat = no_by_treat_l + no_by_treat_r
                    tree = merge_trees(tree_l, tree_r, name_x[m_i],
                                       type_x[m_i], val_x, treedepth)
                if no_further_splits:
                    return tree, reward, no_by_treat
    return tree, reward, no_by_treat


def only_1st_tree_fct(data_ps, c_dict):
    """Find out if further splits make any sense. NOT USED.

    Returns True if every observation (row) has its highest cost-adjusted
    policy score in the same treatment as the first observation. Empty
    input also gives True.
    """
    best_by_obs = (np.argmax(row - c_dict['costs_of_treat'])
                   for row in data_ps)
    first_best = next(best_by_obs, None)
    if first_best is None:
        return True
    return all(best == first_best for best in best_by_obs)


def only_1st_tree_fct2(data_ps, c_dict):
    """Find out if further splits make any sense. NOT USED.

    Vectorised variant: True if every row of the cost-adjusted policy
    scores has its maximum in the same column as the first row.
    """
    best_per_row = np.argmax(data_ps - c_dict['costs_of_treat'], axis=1)
    return np.all(best_per_row == best_per_row[0])


def only_1st_tree_fct3(data_ps, c_dict):
    """Find out if further splits make any sense.

    Subtracts the treatment costs from the policy scores. Returns True if
    all observations then have their highest value in the same treatment.
    The check itself is done by the numba-compiled ``all_same_max_numba``.
    """
    return all_same_max_numba(data_ps - c_dict['costs_of_treat'])


@njit
def all_same_max_numba(data):
    """Check whether the same category has the maximum in every row.

    Parameters
    ----------
    data : 2D numpy array. Rows are observations, columns are categories.

    Returns
    -------
    no_further_splitting : Boolean. True if ``np.argmax`` is identical for
        all rows (also True for an empty array).

    """
    no_further_splitting = True
    for i in range(len(data)):
        if i == 0:
            ref_val = np.argmax(data[i, :])  # reference: first row
        else:
            opt_treat = np.argmax(data[i, :])
            if ref_val != opt_treat:
                no_further_splitting = False
                break  # a single mismatch decides
    return no_further_splitting


def tree_search_multip_single(data_ps, data_ps_diff, data_x, name_x, type_x,
                              values_x, c_dict, treedepth, m_i):
    """Build tree. Only first level. For multiprocessing only.

    Worker for a single split variable. It evaluates all candidate root
    splits on variable ``m_i`` and builds the best subtrees below each side
    with ``tree_search``.

    Parameters
    ----------
    data_ps : Numpy array. Policy scores.
    data_ps_diff : Numpy array. Policy scores relative to reference category.
    data_x : Numpy array. Policy variables.
    name_x : List of strings. Name of policy variables.
    type_x : List of strings. Type of policy variable.
    values_x : List of sets. Values of x for non-continuous variables.
    c_dict : Dict. Parameters.
    treedepth : Int. Current depth of tree (must be larger than 1).
    m_i : Int. Index of the variable used for the root split.

    Returns
    -------
    tree : List of lists. Best tree with root split on variable m_i (None if
        no admissible split exists).
    reward : Float. Total reward that comes from this tree (-inf if none).
    no_by_treat : List of int. Number of treated by treatment state (0-...)

    """
    assert treedepth != 1, 'This should not happen in Multiprocessing.'
    reward, tree, no_by_treat = -math.inf, None, None
    # Same minimum side size as tree_search uses at this depth. A smaller
    # side cannot be split down to the leaves (the recursive search returns
    # reward -inf for it), so skipping it only saves work. It also makes the
    # parallel path consistent with the serial one.
    min_leaf_size = c_dict['ft_min_leaf_size'] * 2**(treedepth - 2)
    if type_x[m_i] == 'cont':
        values_x_to_check = get_values_cont_x(
            data_x[:, m_i], c_dict['ft_no_of_evalupoints'],
            with_numba=c_dict['with_numba'])
    elif type_x[m_i] == 'disc':
        values_x_to_check = values_x[m_i][:]
    else:
        if treedepth < c_dict['ft_depth']:
            values_x_to_check = optp_ta.combinations_categorical(
                data_x[:, m_i], data_ps_diff, c_dict)
        else:
            values_x_to_check = values_x[m_i][:]
    for val_x in values_x_to_check:
        if type_x[m_i] == 'unord':
            left = np.isin(data_x[:, m_i], val_x)
        else:
            left = data_x[:, m_i] <= (val_x + 1e-15)
        obs_left = np.count_nonzero(left)
        if not min_leaf_size <= obs_left <= (len(left) - min_leaf_size):
            continue
        right = np.invert(left)
        tree_l, reward_l, no_by_treat_l = tree_search(
            data_ps[left, :], data_ps_diff[left, :], data_x[left, :],
            name_x, type_x, values_x, c_dict, treedepth - 1)
        tree_r, reward_r, no_by_treat_r = tree_search(
            data_ps[right, :], data_ps_diff[right, :], data_x[right, :],
            name_x, type_x, values_x, c_dict, treedepth - 1)
        if c_dict['restricted']:
            reward_l, reward_r = adjust_reward(
                no_by_treat_l, no_by_treat_r, reward_l, reward_r, c_dict)
        if reward_l + reward_r > reward:
            reward = reward_l + reward_r
            no_by_treat = no_by_treat_l + no_by_treat_r
            tree = merge_trees(tree_l, tree_r, name_x[m_i],
                               type_x[m_i], val_x, treedepth)
    return tree, reward, no_by_treat


def adjust_reward(no_by_treat_l, no_by_treat_r, reward_l, reward_r, c_dict):
    """Adjust rewards if restrictions are violated.

    Dispatches to the numba or the pure-numpy implementation according to
    ``c_dict['with_numba']``. The capacities are read from
    ``c_dict['max_by_treat']``.

    Parameters
    ----------
    no_by_treat_l : Numpy array or None. Observations by treatment, left.
    no_by_treat_r : Numpy array or None. Observations by treatment, right.
    reward_l : Float. Reward of left subtree.
    reward_r : Float. Reward of right subtree.
    c_dict : Dict. Parameter.

    Returns
    -------
    reward_l : Float. Possibly reduced reward of left subtree.
    reward_r : Float. Possibly reduced reward of right subtree.

    """
    adjust = (adjust_reward_numba if c_dict['with_numba']
              else adjust_reward_no_numba)
    return adjust(no_by_treat_l, no_by_treat_r, reward_l, reward_r,
                  c_dict['max_by_treat'])


@njit
def adjust_reward_numba(no_by_treat_l, no_by_treat_r, reward_l, reward_r,
                        max_by_treat):
    """Adjust rewards if restrictions are violated.

    If the combined allocation of both subtrees exceeds the capacity of any
    treatment, both rewards are reduced by ``diff * |reward|``. Here
    ``diff`` is the largest relative excess over capacity, capped at 1.

    Parameters
    ----------
    no_by_treat_l : Numpy array. Observations by treatment, left subtree.
    no_by_treat_r : Numpy array. Observations by treatment, right subtree.
    reward_l : Float. Reward of left subtree.
    reward_r : Float. Reward of right subtree.
    max_by_treat : Numpy array. Maximum number of observations by treatment.

    Returns
    -------
    reward_l : Float. Possibly reduced reward of left subtree.
    reward_r : Float. Possibly reduced reward of right subtree.

    """
    if not ((no_by_treat_l is None) or (no_by_treat_r is None)):
        no_by_treat = no_by_treat_l + no_by_treat_r
        violations = no_by_treat > max_by_treat
        if np.any(violations):
            # Largest relative excess over capacity, capped at 100 %.
            diff_max = ((no_by_treat - max_by_treat) / max_by_treat).max()
            diff = min(diff_max, 1)
            reward_l = reward_l - diff * np.abs(reward_l)
            reward_r = reward_r - diff * np.abs(reward_r)
    return reward_l, reward_r


def adjust_reward_no_numba(no_by_treat_l, no_by_treat_r, reward_l, reward_r,
                           max_by_treat):
    """Adjust rewards if restrictions are violated.

    If the combined allocation of both subtrees exceeds the capacity of any
    treatment, both rewards are reduced by ``penalty * |reward|``. Here
    ``penalty`` is the largest relative excess over capacity, capped at 1.
    If either allocation is None, the rewards are returned unchanged.

    Parameters
    ----------
    no_by_treat_l : Numpy array or None. Observations by treatment, left.
    no_by_treat_r : Numpy array or None. Observations by treatment, right.
    reward_l : Float. Reward of left subtree.
    reward_r : Float. Reward of right subtree.
    max_by_treat : List of Int. Maximum number of observations by treatment.

    Returns
    -------
    reward_l : Float. Possibly reduced reward of left subtree.
    reward_r : Float. Possibly reduced reward of right subtree.

    """
    if no_by_treat_l is not None and no_by_treat_r is not None:
        combined = no_by_treat_l + no_by_treat_r
        if np.any(combined > max_by_treat):
            relative_excess = (combined - max_by_treat) / max_by_treat
            penalty = min(relative_excess.max(), 1)
            reward_l = reward_l - penalty * np.abs(reward_l)
            reward_r = reward_r - penalty * np.abs(reward_r)
    return reward_l, reward_r


def adjust_policy_score(datafile_name, c_dict, v_dict):
    """
    Adjust the policy score to account for insignificant effects.

    For every observation and every effect j (vs. treatment 0), suppose the
    estimated effect is positive but not significant, i.e. its one-sided
    p-value exceeds ``c_dict['sig_level_vs_0']``. Then the policy score of
    treatment j + 1 is set slightly (1e-8) below the policy score of
    treatment 0.

    Parameters
    ----------
    datafile_name (str): Name of data file (read with ``pd.read_csv``).
    c_dict (dict): Controls; uses 'sig_level_vs_0' and 'with_output'.
    v_dict (dict): Variables; uses 'polscore_name', 'effect_vs_0' and
        'effect_vs_0_se'.

    Returns
    -------
    data_ps (numpy array, N x no of treatment): Adjusted policy scores.
    data_df (pandas DataFrame): Data as read from file (not modified).

    """
    data_df = pd.read_csv(datafile_name)
    data_ps = data_df[v_dict['polscore_name']].to_numpy(copy=True)
    data_ps_vs_0 = data_df[v_dict['effect_vs_0']].to_numpy()
    data_ps_vs_0_se = data_df[v_dict['effect_vs_0_se']].to_numpy()
    # t-distribution with 1e6 degrees of freedom (practically normal).
    p_val = sct.t.sf(np.abs(data_ps_vs_0 / data_ps_vs_0_se), 1000000)  # 1sided
    no_of_effects = len(v_dict['effect_vs_0'])
    # Vectorised instead of a Python loop over observations x effects.
    # Column 0 (the reference) is never modified, so all rows and effects
    # can be recoded in one step.
    recode = ((data_ps_vs_0[:, :no_of_effects] > 0)
              & (p_val[:, :no_of_effects] > c_dict['sig_level_vs_0']))
    recoded_cols = slice(1, no_of_effects + 1)
    data_ps[:, recoded_cols] = np.where(
        recode, data_ps[:, :1] - 1e-8,  # a bit smaller than treatment 0
        data_ps[:, recoded_cols])
    no_of_recoded = int(np.count_nonzero(recode))
    if c_dict['with_output']:
        print()
        print(f'{no_of_recoded:5d} policy scores recoded')
    return data_ps, data_df


def sequential_tree_proc(datafile_name, x_type, x_value, v_dict, c_dict):
    """Build sequential policy tree.

    Reads and prepares the data, then grows the tree with
    ``seq_tree_search``. No multiprocessing is used here.

    Parameters
    ----------
    datafile_name: String.
    x_type : Dict. Type information of variables.
    x_value : Dict. Value information of variables.
    v_dict: Dict. Variables.
    c_dict : Dict. Parameters.

    Returns
    -------
    seq_tree : List of lists. The sequential tree.
    seq_reward : Float. Reward as returned by ``seq_tree_search``.
    obs_total : Observations by treatment as returned by ``seq_tree_search``.

    """
    if c_dict['with_output']:
        print('Building sequential policy / decision tree')
        print('No multiprocessing for sequential tree building (not yet).')
    prepared = optp_ta.prepare_data_for_tree_builddata(
        datafile_name, c_dict, v_dict, x_type, x_value)
    data_x, data_ps, data_ps_diff, name_x, type_x, values_x = prepared
    return seq_tree_search(data_ps, data_ps_diff, data_x, name_x, type_x,
                           values_x, c_dict)


def optimal_tree_proc(datafile_name, x_type, x_value, v_dict, c_dict):
    """Build optimal policy tree.

    If ``c_dict['parallel']`` is set, one call of
    ``tree_search_multip_single`` is launched per index ``m_i`` in
    ``range(len(type_x))``. The calls run either as Ray tasks
    (``c_dict['mp_with_ray']``) or in a ``ProcessPoolExecutor``. Among the
    returned trees, the one with the highest reward is kept. Otherwise a
    single serial ``tree_search`` is run.

    Parameters
    ----------
    datafile_name: String. Passed on to
        ``optp_ta.prepare_data_for_tree_builddata``.
    x_type : Dict. Type information of variables.
    x_value : Dict. Value information of variables.
    v_dict: Dict. Variables.
    c_dict : Dict. Parameters (reads 'with_output', 'parallel',
        'no_parallel', 'mp_with_ray', 'ft_depth').

    Returns
    -------
    optimal_tree :  List of lists.
    optimal_reward: Float. Rewards of tree.
    obs_total: Int. Number of observations.

    Raises
    ------
    ValueError : If the parallel search is requested but ``type_x`` is
        empty, so there is no tree to choose from.

    """
    if c_dict['with_output']:
        print('\nBuilding optimal policy / decision tree')
    (data_x, data_ps, data_ps_diff, name_x, type_x, values_x
     ) = optp_ta.prepare_data_for_tree_builddata(datafile_name, c_dict,
                                                 v_dict, x_type, x_value)
    if not c_dict['parallel']:
        optimal_tree, optimal_reward, obs_total = tree_search(
            data_ps, data_ps_diff, data_x, name_x, type_x, values_x, c_dict,
            c_dict['ft_depth'])
        return optimal_tree, optimal_reward, obs_total

    no_of_x = len(type_x)
    if no_of_x == 0:
        # Without this, np.argmax below fails with an opaque message.
        raise ValueError('No covariates available for building the optimal'
                         ' policy tree.')
    maxworkers = c_dict['no_parallel']
    x_trees = []  # Results in order of completion (order is irrelevant).
    if c_dict['mp_with_ray']:
        # An already running Ray instance is reused; Ray is not shut down
        # here.
        if not ray.is_initialized():
            ray.init(num_cpus=maxworkers, include_dashboard=False)
        # Store the data arrays once in the object store and pass refs.
        data_x_ref = ray.put(data_x)
        data_ps_ref = ray.put(data_ps)
        data_ps_diff_ref = ray.put(data_ps_diff)
        still_running = [ray_tree_search_multip_single.remote(
            data_ps_ref, data_ps_diff_ref, data_x_ref, name_x, type_x,
            values_x, c_dict, c_dict['ft_depth'], m_i)
            for m_i in range(no_of_x)]
        while still_running:
            finished, still_running = ray.wait(still_running)
            for ret_all_i in ray.get(finished):
                x_trees.append(ret_all_i)
                if c_dict['with_output']:
                    gp.share_completed(len(x_trees), no_of_x)
    else:
        with futures.ProcessPoolExecutor(max_workers=maxworkers) as fpp:
            trees = [fpp.submit(tree_search_multip_single, data_ps,
                                data_ps_diff, data_x, name_x, type_x,
                                values_x, c_dict, c_dict['ft_depth'], m_i)
                     for m_i in range(no_of_x)]
            # start=1: report the number of tasks completed so far, as in
            # the Ray branch (previously started at 0 and lagged by one).
            for idx, val in enumerate(futures.as_completed(trees), start=1):
                if c_dict['with_output']:
                    gp.share_completed(idx, no_of_x)
                x_trees.append(val.result())
    # Each result is indexed as (tree, reward, obs_total).
    rewards = np.array([tree[1] for tree in x_trees], dtype=float)
    max_i = np.argmax(rewards)  # First maximum wins on ties.
    optimal_tree, obs_total = x_trees[max_i][0], x_trees[max_i][2]
    optimal_reward = rewards[max_i]
    return optimal_tree, optimal_reward, obs_total


@ray.remote
def ray_tree_search_multip_single(data_ps, data_ps_diff, data_x, name_x,
                                  type_x, values_x, c_dict, treedepth, m_i):
    """Run ``tree_search_multip_single`` as a Ray remote task.

    The ``@ray.remote`` decorator lets callers launch this function with
    ``.remote(...)``. All arguments are forwarded positionally and
    unchanged, and the result of ``tree_search_multip_single`` is returned
    as is.
    """
    forwarded = (data_ps, data_ps_diff, data_x, name_x, type_x, values_x,
                 c_dict, treedepth, m_i)
    return tree_search_multip_single(*forwarded)


def final_leaf_dict(leaf, left_right):
    """Generate a dictionary used in evaluating the policy tree.

    Parameters
    ----------
    leaf : List. Leaf description. Entries 5, 6 and 7 are used as the
        splitting variable's name, its type and its cut-off value (or set
        of values).
    left_right : String. Side of the split, stored under 'left or right'.

    Returns
    -------
    return_dic : Dict. Keys 'x_name', 'x_type', 'cut-off or set' and
        'left or right'.

    Raises
    ------
    ValueError : If ``leaf[5]``, ``leaf[6]``, ``leaf[7]`` or ``left_right``
        is None. ValueError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.

    """
    x_name, x_type, cut_off = leaf[5], leaf[6], leaf[7]
    # Use identity checks: the cut-off may be a list/array of values, for
    # which `None in (...)` (equality-based) could be ambiguous.
    if any(entry is None for entry in (x_name, x_type, cut_off, left_right)):
        # Report the offending leaf in the exception instead of printing it
        # to stdout separately.
        raise ValueError(f'No valid entries in final leaf: {leaf}')
    return_dic = {'x_name': x_name, 'x_type': x_type,
                  'cut-off or set': cut_off, 'left or right': left_right}
    return return_dic
